import torch
from torch import nn

# 定义一个函数来计算卷积层。它对输入和输出做相应的升维和降维
def comp_conv2d(conv2d, X):
    # (1, 1)代表批量大小和通道数
    X = X.view((1, 1) + X.shape) # 元组相加，相当于前后连接如(1, 1)+(8, 8)=(1, 1, 8, 8)
    Y = conv2d(X)
    return Y.view(Y.shape[2:])  # 排除不关心的前两维：批量和通道

# 注意这里是两侧分别填充1行或列，所以在两侧一共填充2行或列
conv2d = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, padding=1)

X = torch.rand(8, 8)
print(comp_conv2d(conv2d, X).shape)

conv2d = nn.Conv2d(1, 1, kernel_size=3, padding=1, stride=2)
print(comp_conv2d(conv2d, X).shape)

# 为了表述简洁，当输入的高和宽两侧的填充数分别为ph和pw时，我们称填充为(ph,pw)。
# 特别地，当ph=pw=p时，填充为p。
# 当在高和宽上的步幅分别为sh和sw时，我们称步幅为(sh,sw)。
# 特别地，当sh=sw=s时，步幅为s。
# 在默认情况下，填充为0，步幅为1。
conv2d = nn.Conv2d(1, 1, kernel_size=(3, 5), padding=(0, 1), stride=(3, 4))
print(comp_conv2d(conv2d, X).shape)